In [9]:
from sklearn.cluster import KMeans
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import style
style.use("ggplot")
%matplotlib inline

K-Means Clustering Example

This example notebook shows how to run K-Means clustering in Python using Scikit-Learn and Pandas. Adapted from https://pythonprogramming.net/flat-clustering-machine-learning-python-scikit-learn/

Step 1: Get Data

The first step is to prepare or generate the data. In this dataset, each observation has only two features, but K-Means can be used with any number of features. Because K-Means is an unsupervised method, the data does not need a "target" column.
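
To illustrate the point about the number of features, here is a minimal sketch that fits K-Means on three features instead of two. The `data3` values, the `z` column, and the `n_init` and `random_state` settings are assumptions made for this illustration and are not part of the original notebook.

```python
import pandas as pd
from sklearn.cluster import KMeans

# Hypothetical three-feature data: two tight groups that are far apart
data3 = pd.DataFrame([[1.0, 2.0, 0.5],
                      [1.5, 1.8, 0.7],
                      [1.0, 0.6, 0.4],
                      [8.0, 8.0, 9.0],
                      [9.0, 11.0, 10.0],
                      [8.5, 9.0, 9.5]], columns=['x', 'y', 'z'])

# n_init and random_state are set explicitly so the run is repeatable
model3 = KMeans(n_clusters=2, n_init=10, random_state=0).fit(data3)

# One centroid per cluster, with one coordinate per feature
print(model3.cluster_centers_.shape)
print(model3.labels_)
```

The centroid array has one row per cluster and one column per feature, so nothing else in the workflow changes when features are added.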


In [24]:
data = pd.DataFrame([[1, 2],
                     [5, 8],
                     [1.5, 1.8],
                     [8, 8],
                     [1, 0.6],
                     [9, 11]], columns=['x', 'y'])
print(data)


Out[24]:
     x     y
0  1.0   2.0
1  5.0   8.0
2  1.5   1.8
3  8.0   8.0
4  1.0   0.6
5  9.0  11.0

Step 2: Build the Model

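As with the supervised models, you first create the model and then call its .fit() method on your data. After fitting, the model holds the centroids and the label assigned to each observation, available through the .cluster_centers_ and .labels_ attributes respectively. The numbering of the clusters is arbitrary, so which group is labeled 0 and which is labeled 1 can differ between runs.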

You can view the complete documentation here: http://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html

K-Means also has a .predict() method which can be used to predict the label for an observation.
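
As a self-contained sketch of .predict(), the snippet below refits a model on the same six observations and then assigns two new points to the nearest centroid. The names `points`, `model`, and `new_points`, the two new coordinates, and the `n_init` and `random_state` settings are assumptions made for this illustration.

```python
import pandas as pd
from sklearn.cluster import KMeans

# Same six observations used in this notebook
points = pd.DataFrame([[1, 2], [5, 8], [1.5, 1.8],
                       [8, 8], [1, 0.6], [9, 11]], columns=['x', 'y'])
model = KMeans(n_clusters=2, n_init=10, random_state=0).fit(points)

# Hypothetical new observations, using the same column names as the training data
new_points = pd.DataFrame([[0, 0], [10, 10]], columns=['x', 'y'])

# Each new point receives the label of its nearest centroid
predicted = model.predict(new_points)
print(predicted)
```

The point (0, 0) should receive the same label as the observations near the lower-left centroid, and (10, 10) the same label as those near the upper-right one. The actual label numbers depend on the fit.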


In [12]:
kmeans = KMeans(n_clusters=2).fit(data)

centroids = kmeans.cluster_centers_
labels = kmeans.labels_

print(centroids)
print(labels)


[[ 1.16666667  1.46666667]
 [ 7.33333333  9.        ]]
[0 1 0 1 0 1]
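
The relationship between the centroids and the labels can be checked directly: each centroid should be the mean of the observations assigned to its cluster. The sketch below is self-contained; the names `points`, `model`, and `group_means`, and the `n_init` and `random_state` settings, are assumptions made for this illustration.

```python
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans

# Same six observations used in this notebook
points = pd.DataFrame([[1, 2], [5, 8], [1.5, 1.8],
                       [8, 8], [1, 0.6], [9, 11]], columns=['x', 'y'])
model = KMeans(n_clusters=2, n_init=10, random_state=0).fit(points)

# Average the observations within each assigned label;
# groupby sorts by label, so row i lines up with cluster_centers_[i]
group_means = points.groupby(model.labels_).mean()
print(group_means)

# Compare the per-label means with the fitted centroids
print(np.allclose(group_means.to_numpy(), model.cluster_centers_))
```

For data this small and well separated, the comparison is expected to hold, which is why the first centroid printed above is simply the average of observations 0, 2, and 4.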

Visualizing the Clusters

The code below plots each observation colored by its assigned cluster and marks the two centroids with an "x".


In [35]:
# Attach each observation's cluster label to the DataFrame
data['labels'] = labels

# Plot each cluster in its own color; the second call reuses the first plot's axes
group1 = data[data['labels']==1].plot( kind='scatter', x='x', y='y', color='DarkGreen', label="Group 1" )
group2 = data[data['labels']==0].plot( kind='scatter', x='x', y='y', color='Brown', ax=group1, label="Group 2" )
group1.legend(loc='upper center', bbox_to_anchor=(0.5, 1.05),
          ncol=3, fancybox=True, shadow=True)

# Mark the centroids with an "x" on the same axes
group1.scatter(centroids[:, 0], centroids[:, 1], marker="x", s=150, linewidths=5, zorder=10)

plt.show()